"Bayesian agglomerative clustering engine" (Bowman, 2012) http://arxiv.org/pdf/1201.3867v5.pdf
Algorithm: greedily merge the pair of states $(i,j)$ most likely to be kinetically identical, as ranked by the approximate log Bayes factor $$\ln B_{ij} \approx \hat{C}_i \, D_{KL}(p_i \| q) + \hat{C}_j \, D_{KL}(p_j \| q)$$
where $\hat{C}_i = \sum_k C_{ik}$ is the total number of observed transitions out of state $i$, $p_i$ is the $i$-th row of the row-normalized transition matrix, and $q = (\hat{C}_i p_i + \hat{C}_j p_j)/(\hat{C}_i + \hat{C}_j)$ is the pooled outgoing-transition distribution of the candidate merged state (this is exactly what BACE_bayes_factor computes below).
"Bayesian Rose Trees" (Blundell, Teh, Heller, 2011) http://arxiv.org/pdf/1203.3468.pdf
Setup:
Bayesian hierarchical clustering approximately solves: $$ T^* = \underset{T\in \mathcal{H}}{\text{argmax }} p(\mathcal{D} | T)$$ i.e. search for the cluster tree $T^*$ that maximizes the marginal likelihood of the data, within some hypothesis class of cluster trees $\mathcal{H}$.
Key idea:
Hierarchical clustering typically considers only binary cluster trees, i.e. the merge operations considered in the greedy search always create a new node with 2 children, which results in spurious "cascading" behavior. We would instead like to consider arbitrary branching structure, so that we are not forced to introduce spurious orderings over indistinguishable subtrees.
Details:
Agglomerative clustering is greedy search over binary trees, iteratively finding the best pair of subtrees to join.
To generalize to arbitrary branching structures, we just have to consider two new merge operations at each iteration alongside the standard join: absorb (attach one subtree as an additional child of the other's root) and collapse (merge the children of both roots under a single node). These correspond to the join/absorb/collapse helpers implemented below.
We can now greedily search over arbitrary branching structures.
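To make this concrete, here is a minimal sketch of how the three merge operations act on two small subtrees (using the RoseTree, join, absorb, and collapse helpers defined later in this notebook):

a = RoseTree([0, 1])
b = RoseTree([2, 3])
joined = join(a, b)         # new root with two children: {{0,1},{2,3}}
print(joined.leaves())      # [0, 1, 2, 3]
collapsed = collapse(a, b)  # one flat root with four children: {0,1,2,3}
print(collapsed.leaves())   # [0, 1, 2, 3]
absorb(a, b)                # mutates a in place: b becomes a child of a's root
print(a.leaves())           # [0, 1, 2, 3]

The three results have identical leaf sets but different branching structure, which is exactly the extra freedom the greedy search gains.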
In Bayesian clustering, our search is aimed at maximizing the marginal likelihood of a tree, $p(\mathcal{D} | T)$, and we may need tricks for efficiently computing, approximating, or lower-bounding this quantity.
Can we extend BACE to use alternate branching structures?
To do that, we'll need to create a model of hierarchies, not just a model of partitions.
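For reference, Blundell et al. score a rose tree $T$ by a recursively defined marginal likelihood that mixes the hypothesis "all of $\mathcal{D}_T$ is one cluster" against the product over subtrees: $$p(\mathcal{D}_T | T) = \pi_T f(\mathcal{D}_T) + (1 - \pi_T) \prod_{T_i \in \text{children}(T)} p(\mathcal{D}_{T_i} | T_i)$$ where $f(\mathcal{D}_T)$ is the marginal probability of the data in $T$ under a single-cluster model and $\pi_T$ is the prior probability that $\mathcal{D}_T$ forms one cluster rather than subdividing.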
Notes / observations:
Brainstorming alternate models:
Other related work:
"Constructing multi-resolution MSMs to elucidate RNA hairpin folding mechanisms" (Huang, Yao, Bowman, Sun, Guibas, Carlsson, Pande, 2010) http://psb.stanford.edu/psb-online/proceedings/psb10/huang.pdf
Abstract: "The procedure to build MSMs using SHC is as follows. (1) Partition the conformations into a large number of states, called microstates, according to their structural similarity. An approximate K-centers clustering algorithm[20] is used here as it gives states with approximately uniform size, resulting in a correlation between the population of each state and its density. (2) Split the microstates into n density levels ordered from high to low density (L= {L1, … Ln}) such that each level contains approximately the same number of conformations. Then construct super density level sets Si, where 12 1 ... i SLL L L = ∪ ∪∪ i− i . Thus each super density level contains all previous levels . (3) Within each super density level (Si), perform spectral clustering to group kinetically related microstates. Metastable regions are better separated at high density super levels, since most of the fuzzy microstates in the transition region are excluded at these levels. Now, build a graph representing the connectivity of the states across super density levels. Then generate gradient flows along the edges of the graph from low to high density levels. Each attraction node (or attractive basin) where the gradient flow ends is assigned to a new metastable state. (4) Assign every microstate not belonging to an attraction node to the metastable state it has the largest transition probability to. Thus we have a complete state decomposition for an MSM. Furthermore, this procedure may be repeated with different super density level sets to construct MSMs at different resolutions. The larger the number of super density levels, the finer the resolution and the larger the number of metastable states in the final MSM. 1 2... i SS S ⊆ ⊆"
"Asymptotic analysis of multiscale Markov chain" (Wei Zhang, 2016) http://arxiv.org/pdf/1512.08944v2.pdf
In [2]:
from msmbuilder.example_datasets import AlanineDipeptide
trajs = AlanineDipeptide().get().trajectories
from msmbuilder.cluster import MiniBatchKMedoids
kmeds = MiniBatchKMedoids(n_clusters=100,metric='rmsd',max_iter=10)
kmeds.fit(trajs)
Out[2]:
In [3]:
dtrajs = kmeds.transform(trajs)
In [4]:
import pyemma
In [5]:
msm = pyemma.msm.estimate_markov_model(dtrajs,1)
In [6]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(msm.transition_matrix,cmap='Blues',interpolation='none')
plt.colorbar()
Out[6]:
In [7]:
lags = range(1, 10) + range(10, 101, 10)  # lag 0 is not a valid MSM lag time
its = pyemma.msm.its(dtrajs,lags,nits=10)
In [8]:
pyemma.plots.plot_implied_timescales(its)
Out[8]:
In [10]:
C = msm.count_matrix_active
In [11]:
from scipy.stats import entropy as kl_div  # entropy(p, q) computes D_KL(p || q)

def BACE_bayes_factor(C, i, j):
    # approximate log Bayes factor from BACE for merging states i and j
    C_hat = C.sum(1)  # total outgoing counts per state
    T = C / C_hat[:, None]  # row-normalized transition matrix
    # pooled outgoing-transition distribution of the merged state
    q = (C_hat[i] * T[i] + C_hat[j] * T[j]) / (C_hat[i] + C_hat[j])
    return C_hat[i] * kl_div(T[i], q) + C_hat[j] * kl_div(T[j], q)
In [13]:
import numpy as np

n = len(C)
mergeability = np.zeros((n, n))
mask = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        # only score pairs of distinct states connected in both directions
        if i != j and C[i, j] > 0 and C[j, i] > 0:
            mergeability[i, j] = -BACE_bayes_factor(C, i, j)
            mask[i, j] = 1
In [14]:
np.max(mergeability[mergeability!=0])
Out[14]:
In [15]:
plt.imshow(-mergeability,cmap='Blues',interpolation='none')
plt.colorbar()
Out[15]:
In [16]:
(mergeability[mergeability!=0]).flatten()
Out[16]:
In [17]:
x = np.random.rand(10,2)
x.sum(1)
Out[17]:
In [18]:
x / x.sum(1)[:,None]
Out[18]:
In [19]:
def fractional_metastability(T):
    return np.trace(T) / len(T)

def product_of_fractional_metastabilities(Ts):
    return np.prod([fractional_metastability(T) for T in Ts])
In [20]:
class RoseTree(object):
    # says which indices are lumped at which level
    def __init__(self, children):
        self.children = children

    def leaves(self):
        leaf_list = []
        for child in self.children:
            if type(child) == RoseTree:
                for leaf in child.leaves():
                    leaf_list.append(leaf)
            else:
                leaf_list.append(child)
        return leaf_list

def join(a, b):
    return RoseTree([a, b])

def absorb(a, b):
    a.children.append(b)

def collapse(a, b):
    return RoseTree(a.children + b.children)
In [21]:
T_0 = RoseTree([RoseTree([i]) for i in range(n)])
In [22]:
T = T_0.children.pop()  # start from one subtree so it isn't joined in twice
while len(T_0.children) > 0:
    T = join(T_0.children.pop(), T)
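A quick sanity check (not part of the original notebook): the fully joined binary tree should now contain each of the $n$ microstate indices exactly once as a leaf.

print(len(T.leaves()) == n)  # True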
In [23]:
np.trace(msm.transition_matrix)/n
Out[23]:
In [24]:
from sklearn.cluster import SpectralBiclustering
bic = SpectralBiclustering(6)
bic.fit(msm.transition_matrix)
In [25]:
msm.transition_matrix.shape
Out[25]:
In [26]:
def cg_T(microstate_T, microstate_pi, cg_map):
    '''Coarse-grain a microstate transition matrix by applying cg_map.

    Parameters
    ----------
    microstate_T : (N,N), array-like, square
        microstate transition matrix
    microstate_pi : (N,), array-like
        microstate stationary distribution
    cg_map : (N,), array-like
        assigns each microstate i to a macrostate cg_map[i]

    Returns
    -------
    T : numpy.ndarray, square
        macrostate transition matrix
    '''
    n_macrostates = np.max(cg_map) + 1
    n_microstates = len(microstate_T)

    # compute macrostate stationary distribution
    macrostate_pi = np.zeros(n_macrostates)
    for i in range(n_microstates):
        macrostate_pi[cg_map[i]] += microstate_pi[i]
    macrostate_pi /= np.sum(macrostate_pi)

    # accumulate macrostate transition matrix
    T = np.zeros((n_macrostates, n_macrostates))
    for i in range(n_microstates):
        for j in range(n_microstates):
            T[cg_map[i], cg_map[j]] += microstate_pi[i] * microstate_T[i, j]

    # normalize
    for a in range(n_macrostates):
        T[a] /= macrostate_pi[a]
    return T
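As a quick sanity check of cg_T (a hypothetical example, not from the original notebook), lumping a 4-state chain with two metastable blocks into 2 macrostates should produce a strongly diagonal 2x2 matrix:

# hypothetical 4-state chain with two metastable blocks
T_micro = np.array([[0.9, 0.1, 0.0, 0.0],
                    [0.1, 0.8, 0.1, 0.0],
                    [0.0, 0.1, 0.8, 0.1],
                    [0.0, 0.0, 0.1, 0.9]])
# stationary distribution = leading left eigenvector, normalized to sum to 1
evals, evecs = np.linalg.eig(T_micro.T)
pi_micro = np.real(evecs[:, np.argmax(np.real(evals))])
pi_micro /= pi_micro.sum()
print(cg_T(T_micro, pi_micro, np.array([0, 0, 1, 1])))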
In [27]:
T = msm.transition_matrix
pi = msm.stationary_distribution
macro_T = cg_T(T,pi,bic.row_labels_)
In [28]:
fractional_metastability(macro_T)
Out[28]:
In [29]:
fractional_metastability(T[bic.row_labels_==0][:,bic.row_labels_==0])
Out[29]:
In [30]:
fractional_metastability(T[bic.row_labels_==1][:,bic.row_labels_==1])
Out[30]:
In [31]:
fractional_metastability(T[bic.row_labels_==2][:,bic.row_labels_==2])
Out[31]:
In [32]:
np.trace(T)
Out[32]:
In [33]:
np.prod([np.trace(T[bic.row_labels_==i][:,bic.row_labels_==i]) for i in range(3)])
Out[33]:
In [34]:
dumb_labels = np.zeros(n)
dumb_labels[n//2:] = 1
In [35]:
np.prod([np.trace(T[dumb_labels==i][:,dumb_labels==i]) for i in range(2)])
Out[35]:
In [36]:
np.trace(T[dumb_labels==0][:,dumb_labels==0])
Out[36]:
In [37]:
np.trace(T[dumb_labels==1][:,dumb_labels==1])
Out[37]:
In [38]:
def plot_contiguous(T, mapping):
    # reorder states so that members of the same macrostate are adjacent
    sorted_inds = np.array(sorted(range(len(T)), key=lambda i: mapping[i]))
    plt.imshow(T[sorted_inds][:, sorted_inds], interpolation='none', cmap='Blues')
    plt.colorbar()
In [39]:
plot_contiguous(T,dumb_labels)
In [40]:
plot_contiguous(T,bic.row_labels_)
In [41]:
submatrix = lambda i:T[bic.row_labels_==i][:,bic.row_labels_==i]
T_0 = submatrix(5)
In [42]:
bic2 = SpectralBiclustering(2)
bic2.fit(T_0)
bic2.row_labels_
Out[42]:
In [43]:
plot_contiguous(T_0,bic2.row_labels_)
In [44]:
macro_Ts = []
Ns = range(2, 30)
for i in Ns:
    bic = SpectralBiclustering(i)
    bic.fit(msm.transition_matrix)
    macro_T = cg_T(T, pi, bic.row_labels_)
    macro_Ts.append(macro_T)
In [45]:
plt.plot(Ns,[np.trace(t) for t in macro_Ts])
Out[45]:
In [46]:
plt.plot(Ns,[np.trace(t)/len(t) for t in macro_Ts])
Out[46]:
In [47]:
plt.plot(Ns,[np.trace(t**2) / (len(t)) for t in macro_Ts])
Out[47]:
In [56]:
plt.plot(Ns,[np.trace(t)**2 / (len(t)) for t in macro_Ts])
Out[56]:
In [48]:
plt.plot(Ns,[np.log(np.trace(t)) /(len(t)) for t in macro_Ts])
Out[48]:
In [49]:
plt.plot(Ns,[np.trace(t)/(len(t)**2) for t in macro_Ts])
Out[49]:
In [50]:
np.linalg.norm(T-np.diag(np.diag(T)))
Out[50]:
In [51]:
np.trace(T)
Out[51]:
In [52]:
np.sum(msm.count_matrix_active)
Out[52]:
In [53]:
np.trace(msm.count_matrix_active)/np.sum(msm.count_matrix_active)
Out[53]:
In [54]:
np.sum(msm.count_matrix_active[dumb_labels==0][:,dumb_labels==0])
Out[54]:
In [55]:
msm.count_matrix_active
Out[55]:
In [63]:
likelihood_unnorm = np.prod(msm.transition_matrix**msm.count_matrix_active)
likelihood_unnorm
Out[63]:
In [64]:
np.min(msm.transition_matrix**msm.count_matrix_active)
Out[64]:
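The elementwise product underflows to zero in double precision. A numerically stable alternative (a sketch, not in the original notebook; variable names are illustrative) is to accumulate the log-likelihood over observed transitions instead:

C_counts = msm.count_matrix_active
T_mle = msm.transition_matrix
observed = C_counts > 0  # avoid log(0) where no transitions were seen
log_likelihood = np.sum(C_counts[observed] * np.log(T_mle[observed]))
print(log_likelihood)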
In [82]:
msm = pyemma.msm.BayesianMSM(1,nsamples=1000)
msm.fit(dtrajs)
In [83]:
mats = np.array([m.transition_matrix for m in msm.samples[:10]])
plt.imshow(mats.mean(0),interpolation='none',cmap='Blues')
plt.title('Mean')
plt.colorbar()
plt.figure()
stdev = mats.std(0)
plt.imshow(stdev,interpolation='none',cmap='Blues')
plt.title('Standard deviation')
plt.colorbar()
plt.figure()
stderr = mats.std(0)/np.sqrt(len(mats))
plt.imshow(stderr,interpolation='none',cmap='Blues')
plt.title('Standard error')
plt.colorbar()
Out[83]:
In [86]:
for sample in msm.samples:
    plt.plot(sample.pi)
In [87]:
stat_dists=np.array([sample.pi for sample in msm.samples])
In [90]:
stat_dists.mean(0).shape
Out[90]:
In [93]:
plt.plot(stat_dists.mean(0))
stderr = stat_dists.std(0)#/msm.nsamples
plt.fill_between(np.arange(msm.nstates),stat_dists.mean(0)+stderr,stat_dists.mean(0)-stderr,alpha=0.4)
Out[93]:
In [95]:
np.argmax(stat_dists.std(0)/stat_dists.mean(0))
Out[95]:
In [98]:
plt.plot(stat_dists.std(0))
Out[98]:
In [99]:
plt.scatter(stat_dists.std(0),stat_dists.mean(0))
Out[99]:
In [97]:
plt.plot(stat_dists.std(0)/stat_dists.mean(0))
Out[97]:
In [ ]: